#include "MinimumJerkStrategy.h"

#include <algorithm>
#include <cmath>
#include <stdexcept>

#include <MMM/MMMCore.h>

using namespace MMM;

/**
 * Creates a minimum-jerk converter.
 *
 * Every argument is forwarded unchanged to ConvertingStrategy; the constructor
 * body itself performs no additional work.
 */
MinimumJerkStrategy::MinimumJerkStrategy(const std::map<float, std::map<std::string, Eigen::Vector3f> > &labeledMarkerData,
                                         MotionPtr outputMotion,
                                         ModelPtr outputModel,
                                         const std::vector<std::string> &joints,
                                         const std::map<std::string, std::string> &markerMapping,
                                         const std::map<std::string, float> &markerWeights)
    : ConvertingStrategy(labeledMarkerData,
                         outputMotion,
                         outputModel,
                         joints,
                         markerMapping,
                         markerWeights)
{
}

/// Asks the nlopt optimizer currently held by this strategy to stop early via force_stop().
/// Only the optimizer instance assigned at the time of the call is affected; convert()
/// creates a fresh optimizer for every window.
void MinimumJerkStrategy::cancel()
{
    auto &activeOptimizer = optimizer;
    activeOptimizer.force_stop();
}

/// This strategy does not track a current timestep; it always reports the sentinel -1.
float MinimumJerkStrategy::getCurrentTimestep()
{
    constexpr float kNoCurrentTimestep = -1.0f;
    return kNoCurrentTimestep;
}

void MinimumJerkStrategy::convert() {
    MMM_INFO << "Convert MinimumJerk" << std::endl;
    //MMM_ASSERT(inputMotion);

    auto sensor = inputMotion->getSensorByType<KinematicSensor>("Kinematic");
    ModelPoseSensorPtr poseSensor = inputMotion->getSensorByType<ModelPoseSensor>("ModelPose");
    timesteps = sensor->getTimesteps();
    if (frameRange > (int) timesteps.size()) {
        throw MMM::Exception::MMMException("Motion requires at least " + std::to_string(frameRange) + " frames to use MinimumJerkStrategy. Motion only has " + std::to_string(timesteps.size()) + " frames.");
    }

    auto frameToConfig = [&](float step)
    {

        KinematicSensorMeasurementPtr measure = boost::dynamic_pointer_cast<KinematicSensorMeasurement>(sensor->getMeasurement(step));

        ModelPoseSensorMeasurementPtr poseMeasure = boost::dynamic_pointer_cast<ModelPoseSensorMeasurement>(poseSensor->getMeasurement(step));

        VR_ASSERT(measure);
        VR_ASSERT(poseMeasure);
        Eigen::Vector3f pos = poseMeasure->getRootPosition();
        Eigen::VectorXf jointAngles = measure->getJointAngles();
        std::vector<double> config;
        config.reserve(jointAngles.rows() + 6);
        config.push_back(pos(0));
        config.push_back(pos(1));
        config.push_back(pos(2));
        Eigen::Vector3f rootRot = poseMeasure->getRootRotation();
        if (rootRot[0] > M_PI) rootRot[0] -= 2 * M_PI;
        else if (rootRot[0] < -M_PI) rootRot[0] += 2 * M_PI;
        if (rootRot[1] > M_PI) rootRot[1] -= 2 * M_PI;
        else if (rootRot[1] < -M_PI) rootRot[1] += 2 * M_PI;
        if (rootRot[2] > M_PI) rootRot[2] -= 2 * M_PI;
        else if (rootRot[2] < -M_PI) rootRot[2] += 2 * M_PI;
        config.push_back(rootRot(0));
        config.push_back(rootRot(1));
        config.push_back(rootRot(2));

        for(int i = 0; i <jointAngles.rows(); i++)
        {
            config.push_back(jointAngles(i));
        }
        return config;
    };
    for(auto & step: timesteps)
    {
        optimizedFrames[step] = frameToConfig(step);
    }

    currentFrameEnd = frameRange;
    currentFrameStart = 0;
    while(true)
    {
        lastPrintLoss = -1;
        std::cout << "start frame: " << currentFrameStart << " end: " << currentFrameEnd << std::endl;
        dimension = (currentFrameEnd-currentFrameStart) * frameDimension;
        // Build initial configuration for optimization
        std::vector<double> configuration(dimension, 0.0);
        std::vector<double> initialConfig;
        // Initialize optimization
        optimizer = nlopt::opt(nloptAlgorithm, dimension);
        optimizer.set_min_objective(ConvertingStrategy::objectiveFunctionWrapperStatic, this);
        optimizer.set_ftol_rel(0.001);
//        optimizer.set_ftol_abs(0.0001);
        optimizer.set_maxtime(600);




//        MMM_INFO << "Using input motion" << std::endl;

        initialConfig.reserve(dimension);

        if(sensor)
        {

            Eigen::VectorXf jointAngles;

            for(int i = currentFrameStart; i < currentFrameEnd; i++)
            {
                float step = timesteps.at(i);
//                std::cout << "step " << step << std::endl;
                auto add = optimizedFrames.at(step);
                initialConfig.insert(initialConfig.end(),add.begin(), add.end());
            }
            Eigen::MatrixXd init = Eigen::Map<Eigen::Matrix<double, -1,-1,Eigen::RowMajor>>(initialConfig.data(), timesteps.size(),jointAngles.rows()+6);
//            MMM_INFO << "init:\n" << init << std::endl;
//            MMM_INFO << "frames: " << labeledMarkerData.size() << " joints: " << frameDimension << std::endl << " input motion frame size: " << jointAngles.rows() << " timesteps: " << timesteps.size() << std::endl;
//            MMM_INFO << "Setting initial config of size " << initialConfig.size() << " needed size: " << dimension << std::endl;
//            MMM_ASSERT(configuration.size() == initialConfig.size());
            configuration = initialConfig;
        }
        else
        {
            MMM_WARNING << "No kinematic sensor in input motion!" << std::endl;
        }

        setOptimizationBounds(optimizer);

        // Run optimization
        MMM_INFO << "Starting optimization of whole motion (" << labeledMarkerData.size() << " frames)..." << std::endl;
        double objectiveValue;
        try {
            nlopt::result resultCode = optimizer.optimize(configuration, objectiveValue);

            MMM_INFO << "Optimization finished with code " << resultCode << "." << std::endl;
        }
        catch (nlopt::roundoff_limited&e) {
            MMM_INFO << "Optimization finished by throwing nlopt::roundoff_limited (the result should be usable)." << std::endl;
        }
        int frameNum = currentFrameStart;
        for(int i = 0; i < currentFrameEnd-currentFrameStart; i++)
        {
            std::vector<double> frameConfiguration(configuration.begin() + i * frameDimension, configuration.begin() + (i + 1) * frameDimension);
//            std::cout << "Updating " << timesteps.at(frameNum) << " with " << frameConfiguration.at(0) << std::endl;
            if(frameNum >= (int)timesteps.size())
                break;
            optimizedFrames[timesteps.at(frameNum)] = frameConfiguration;
            frameNum++;
        }
        if(currentFrameEnd >= (int)timesteps.size()-1)
        {
            break;
        }
        currentFrameStart += frameRange*(1.0f-overlapPercent);
        currentFrameEnd = std::min<int>(currentFrameStart+frameRange, timesteps.size());
    }
    double maxJerk, avgJerk;
    std::tie(maxJerk,avgJerk) = calculateMaxAndAverageJerk(optimizedFrames);
    MMM_INFO << "Max jerk: " << maxJerk << " avg jerk: " << avgJerk << std::endl;
//    MMM_INFO << "Copying results to motion" << std::endl;
    int frameNum = 0;
    for (const auto &labeledMarker : labeledMarkerData) {
        std::vector<double> frameConfiguration = optimizedFrames.at(labeledMarker.first);

        float timestep = labeledMarker.first;
        Eigen::Vector3f rootPos = Eigen::Vector3f::Zero();
        Eigen::Vector3f rootRot = Eigen::Vector3f::Zero();
        // calculate ModelPoseSensorMeasurement
        rootPos[0] += frameConfiguration[0]; rootPos[1] += frameConfiguration[1]; rootPos[2] += frameConfiguration[2];
        rootRot[0] += frameConfiguration[3]; rootRot[1] += frameConfiguration[4]; rootRot[2] += frameConfiguration[5];
        if (rootRot[0] > M_PI) rootRot[0] -= 2 * M_PI;
        else if (rootRot[0] < -M_PI) rootRot[0] += 2 * M_PI;
        if (rootRot[1] > M_PI) rootRot[1] -= 2 * M_PI;
        else if (rootRot[1] < -M_PI) rootRot[1] += 2 * M_PI;
        if (rootRot[2] > M_PI) rootRot[2] -= 2 * M_PI;
        else if (rootRot[2] < -M_PI) rootRot[2] += 2 * M_PI;

        ModelPoseSensorMeasurementPtr modelPoseSensorMeasurement(new ModelPoseSensorMeasurement(timestep, rootPos, rootRot));
        outputModelPoseSensor->addSensorMeasurement(modelPoseSensorMeasurement);

        // calculate KinematicSensorMeasurement
        Eigen::VectorXf jointValues(joints.size());
        for (int i = 0; i < jointValues.rows(); ++i) {
            jointValues[i] = frameConfiguration[6 + i];
        }

        KinematicSensorMeasurementPtr kinematicSensorMeasurement(new KinematicSensorMeasurement(timestep, jointValues));
        outputKinematicSensor->addSensorMeasurement(kinematicSensorMeasurement);

        frameNum++;
    }
}


double MinimumJerkStrategy::objectiveFunction(const std::vector<double> &configuration, std::vector<double> &grad) {
    if (!grad.empty()) {
        MMM_ERROR << "NloptConverter: Gradient computation not supported!" << std::endl;
        return 0.0;
    }

    if (configuration.size() != dimension) {
        MMM_ERROR << "NloptConverter: x has wrong number of dimensions (" << configuration.size() << ")!" << std::endl;
        return 0.0;
    }
    //    std::cout << "Starting objective func" << std::endl;
    double totalSumDistanceSquares = 0.0;
    std::vector<double> currentPosRot(6, 0.0);

//    int frameNum = currentFrameStart;
    std::unique_ptr<Eigen::VectorXd> previousJointPos, previousJointVel, previousJointAcc;
    double jerkSum = 0.0;
    //    std::vector<double> previousJointData;
    float previousTimestep = 0.0f;
    double maxJerk = 0.0;
    int maxJerkIndex = -1;
    size_t i = 0;
    auto currentStartFrame = std::max(0,currentFrameStart-3);
    auto currentEndFrame = std::min<int>(currentFrameEnd+3, timesteps.size());
    for(size_t frameNum = currentStartFrame; frameNum < (size_t)currentEndFrame; frameNum++) {

        auto timestep = timesteps.at(frameNum);
        auto &labeledMarker = labeledMarkerData.at(timestep);

        std::vector<double> frameConfiguration;
        if(frameNum < (size_t)currentFrameStart || frameNum >= (size_t)currentFrameEnd)
        {
            frameConfiguration = optimizedFrames.at(timestep);
        }
        else
        {
            frameConfiguration = std::vector<double>(configuration.begin() + i * frameDimension, configuration.begin() + (i + 1) * frameDimension);
            i++;
        }
//        std::vector<double> frameConfiguration(configuration.begin() + frameNum * frameDimension, configuration.begin() + (frameNum + 1) * frameDimension);
        //        for (int i = 0; i < 6; ++i) {
        //            currentPosRot[i] = frameConfiguration[i];
        //            frameConfiguration[i] = currentPosRot[i];
        //        }
//        Eigen::VectorXd vec = Eigen::Map<Eigen::VectorXd>(frameConfiguration.data(), frameConfiguration.size());
        setOutputModelConfiguration(frameConfiguration);
        //        std::cout << "Vec:\n" << vec.head(3) << std::endl;
        totalSumDistanceSquares += calculateMarkerDistancesSquaresSum(labeledMarker);
        auto tDelta = timestep - previousTimestep;
        Eigen::VectorXd jointPos = Eigen::Map<Eigen::VectorXd>(frameConfiguration.data(), frameConfiguration.size());
        Eigen::VectorXd orientation = jointPos.block<3,1>(3,0);
        jointPos << jointPos.head(3), jointPos.tail(jointPos.rows()-6); // remove orientation since it cannot be derived like in the following
        jointPos.head(3) /= 100; // make position smaller to fit to radian for optimization
        if(previousJointPos)
        {
            if(previousJointPos->rows() != (long)frameConfiguration.size())
            {
                MMM_ERROR << "Different joint vector sizes!" << std::endl;
                throw std::runtime_error("Different joint vector sizes!");
            }

            Eigen::VectorXd jointVel = (jointPos - *previousJointPos)/tDelta;
            if(previousJointVel)
            {
                Eigen::VectorXd jointAcc = (jointVel - *previousJointVel)/tDelta;
                if(previousJointAcc)
                {
                    Eigen::VectorXd jointJerk = (jointAcc-*previousJointAcc)/tDelta;
                    if(std::abs(jointJerk.maxCoeff()) > maxJerk)
                    {
                        maxJerk = std::abs(jointJerk.maxCoeff());
                        maxJerkIndex = frameNum;
                    }

                    jerkSum += jointJerk.squaredNorm();
                }
                else
                {
                    previousJointAcc.reset(new Eigen::VectorXd());
                }
                *previousJointAcc = jointAcc;
            }
            else
            {
                previousJointVel.reset(new Eigen::VectorXd());
            }
            *previousJointVel = jointVel;



        }
        else
        {
            previousJointPos.reset(new Eigen::VectorXd());
        }
        previousTimestep = timestep;
        *previousJointPos = jointPos;

    }
    float currentFrameRange =currentEndFrame - currentStartFrame;
    jerkSum *= 0.000001; // adjustment
    auto loss = totalSumDistanceSquares + jerkSum;
    if(loss < lastPrintLoss*0.99 || lastPrintLoss < 0)
    {
//        std::cout << "loss function from " << currentStartFrame << " to " << currentEndFrame << std::endl;
        std::cout << "current squared marker distance: " << totalSumDistanceSquares/currentFrameRange << " jerkSum: " << jerkSum/currentFrameRange << " current maxJerk: " << maxJerk << " at frame " << maxJerkIndex << std::endl;
        lastPrintLoss = loss;
    }
    return loss;
}


/**
 * Sets nlopt box constraints for every frame of the current window
 * [currentFrameStart, currentFrameEnd).
 *
 * Per frame the bounds are: root translation in [-10000, 10000] (10m), root rotation in
 * [-pi, pi], and each joint in `joints` limited to its model joint limits
 * (limitLo / limitHi). All frames of the window share the same bounds, so they are
 * computed once and replicated instead of repeating the model node lookups for every frame.
 *
 * @param optimizer Optimizer whose lower/upper bounds are replaced.
 */
void MinimumJerkStrategy::setOptimizationBounds(nlopt::opt& optimizer) const {
    // Some algorithms cannot handle unconstrained components (i.e. upper/lower limit of +/- infinity),
    // so the root pose gets large but finite limits.
    const double positionLowerBound = -10000.0, positionUpperBound = 10000.0;  // 10m
    const double rotationLowerBound = -M_PI, rotationUpperBound = M_PI;

    // Bounds of a single frame: 3 translation entries, 3 rotation entries, then one entry per joint.
    std::vector<double> frameLowerBounds, frameUpperBounds;
    frameLowerBounds.reserve(6 + joints.size());
    frameUpperBounds.reserve(6 + joints.size());
    frameLowerBounds.insert(frameLowerBounds.end(), 3, positionLowerBound);
    frameUpperBounds.insert(frameUpperBounds.end(), 3, positionUpperBound);
    frameLowerBounds.insert(frameLowerBounds.end(), 3, rotationLowerBound);
    frameUpperBounds.insert(frameUpperBounds.end(), 3, rotationUpperBound);
    for (const auto &jointName : joints) {
        const JointInfo jointInfo = outputModel->getModelNode(jointName)->joint;
        frameLowerBounds.push_back(jointInfo.limitLo);
        frameUpperBounds.push_back(jointInfo.limitHi);
    }

    // Replicate the per-frame bounds for every frame in the window.
    const int windowFrames = currentFrameEnd - currentFrameStart;
    std::vector<double> lowerBounds, upperBounds;
    if (windowFrames > 0) {
        lowerBounds.reserve(windowFrames * frameLowerBounds.size());
        upperBounds.reserve(windowFrames * frameUpperBounds.size());
    }
    for (int frame = 0; frame < windowFrames; ++frame) {
        lowerBounds.insert(lowerBounds.end(), frameLowerBounds.begin(), frameLowerBounds.end());
        upperBounds.insert(upperBounds.end(), frameUpperBounds.begin(), frameUpperBounds.end());
    }

    optimizer.set_lower_bounds(lowerBounds);
    optimizer.set_upper_bounds(upperBounds);
}

int MinimumJerkStrategy::getFrameRange() const
{
    return frameRange;
}

void MinimumJerkStrategy::setFrameRange(int value)
{
    frameRange = value;
}

float MinimumJerkStrategy::getOverlapPercent() const
{
    return overlapPercent;
}

void MinimumJerkStrategy::setOverlapPercent(float value)
{
    overlapPercent = value;
}

MotionPtr MinimumJerkStrategy::getInputMotion() const
{
    return inputMotion;
}

void MinimumJerkStrategy::setInputMotion(const MotionPtr &value)
{
    inputMotion = value;
}
